{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# nn.Module 继承 来构造模型\n",
    "\n",
    "`Module`类是`nn`模块里提供的一个模型构造类，是所有神经网络模块的基类，我们可以继承它来定义我们想要的模型。下面继承`Module`类构造本节开头提到的多层感知机。这里定义的`MLP`类重载了`Module`类的`__init__`函数和`forward`函数。它们分别用于创建模型参数和定义前向计算。前向计算也即正向传播。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\") \n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    # 声明带有模型参数的层，这里声明了两个全连接层\n",
    "    def __init__(self, **kwargs):\n",
    "        # 调用MLP父类Block的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数\n",
    "        # 参数，如“模型参数的访问、初始化和共享”一节将介绍的模型参数params\n",
    "        super(MLP, self).__init__(**kwargs)\n",
    "        self.hidden = nn.Linear(28*28, 256)\n",
    "        self.act = nn.ReLU()\n",
    "        self.out = nn.Linear(256,10)\n",
    "        \n",
    "        # 定义模型的前向计算，即如何根据输入x计算返回所需要的模型输出\n",
    "    def forward(self, x):\n",
    "        x = self.act(self.hidden(x))\n",
    "        return self.out(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0836, -0.0942,  0.2857, -0.1599, -0.0923, -0.3792, -0.5406, -0.1677,\n",
      "          0.0389, -0.2380],\n",
      "        [-0.1376,  0.1140, -0.1313,  0.1248,  0.0149, -0.3359, -0.4447, -0.0092,\n",
      "         -0.2264, -0.1507]], grad_fn=<AddmmBackward>)\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(2,1,28*28)\n",
    "x = x.view(x.size()[0], -1)\n",
    "net = MLP()\n",
    "out = net(x)\n",
    "print(out)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP(\n",
      "  (hidden): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (act): ReLU()\n",
      "  (out): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意，这里并没有将`Module`类命名为`Layer`（层）或者`Model`（模型）之类的名字，这是因为该类是一个可供自由组建的部件。它的子类既可以是一个层（如PyTorch提供的`Linear`类），又可以是一个模型（如这里定义的`MLP`类），或者是模型的一个部分。我们下面通过两个例子来展示它的灵活性。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 快速构建简单的网络，不需要写forward\n",
    "\n",
    "## 都在nn工具箱下\n",
    "\n",
    "\n",
    "### Squential 自动forward\n",
    "\n",
    "1. Sequential 直接接受nn模块\n",
    ">对于 Sequential的对象，可以 add_modules(继承Module的实例)\n",
    "\n",
    "2. ModuleList 接受nn模块的list\n",
    "\n",
    "> 方便像list 那样 append, extend\n",
    "\n",
    "3. ModulDict  接受字典，方便命名，按名字访问，按名字添加属性，layer\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n",
      "tensor([[-0.5808,  0.0198, -0.2878, -0.3476, -0.2464, -0.1148,  0.4248, -0.3322,\n",
      "         -0.1394,  0.6659],\n",
      "        [-0.5151,  0.1577, -0.0254, -0.1316,  0.3060, -0.0942,  0.3053, -0.3888,\n",
      "          0.3985,  0.2867]], grad_fn=<AddmmBackward>)\n"
     ]
    }
   ],
   "source": [
    "from torch.nn import Sequential\n",
    "net = Sequential(\n",
    "        nn.Linear(784, 256),\n",
    "        nn.ReLU(),\n",
    "        nn.Linear(256, 10), \n",
    "        )\n",
    "print(net)\n",
    "out = net(x)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ModuleList(\n",
      "  (0): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n",
      "tensor([[ 0.1271,  0.3189,  0.4812, -0.2085, -0.0490,  0.4415, -0.0589,  0.3638,\n",
      "         -0.6852,  0.1257],\n",
      "        [-0.1920,  0.1337,  0.2488, -0.3550,  0.2886,  0.6090,  0.1322,  0.4318,\n",
      "         -0.0815,  0.2190]], grad_fn=<AddmmBackward>)\n",
      "ModuleList(\n",
      "  (0): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
      "  (3): Linear(in_features=10, out_features=2, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = nn.ModuleList([nn.Linear(784, 256),\n",
    "        nn.ReLU(),\n",
    "        nn.Linear(256, 10)])\n",
    "print(net)\n",
    "# 下面会报错,因为modellist没有实现forward方法\n",
    "out = x\n",
    "for layer in net:\n",
    "    out = layer(out)\n",
    "print(out)\n",
    "\n",
    "net.append(nn.Linear(10,2))\n",
    "\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ModuleDict(\n",
      "  (act): ReLU()\n",
      "  (linear): Linear(in_features=784, out_features=256, bias=True)\n",
      ")\n",
      "torch.Size([2, 256])\n",
      "ModuleDict(\n",
      "  (act): ReLU()\n",
      "  (linear): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (layer2): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = nn.ModuleDict({\n",
    "    'linear':nn.Linear(28*28, 256),\n",
    "    'act': nn.ReLU()\n",
    "})\n",
    "# 同样 module dict没有forward\n",
    "#out = net(x)\n",
    "print(net)\n",
    "out = x\n",
    "\n",
    "out = net['linear'](out)\n",
    "print(out.shape)\n",
    "\n",
    "#加入新的layer\n",
    "net['layer2'] = nn.Linear(256,10)\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构造复杂的模型\n",
    "\n",
    "虽然上面介绍的这些类可以使模型构造更加简单，且不需要定义`forward`函数，但直接继承`Module`类可以极大地拓展模型构造的灵活性。下面我们构造一个稍微复杂点的网络`FancyMLP`。在这个网络中，我们通过`get_constant`函数创建训练中不被迭代的参数，即常数参数。在前向计算中，除了使用创建的常数参数外，我们还使用`Tensor`的函数和Python的控制流，并多次调用相同的层。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FancyMLP(nn.Module):\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NestMLP(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(NestMLP, self).__init__(**kwargs)\n",
    "        self.net = nn.Sequential(nn.Linear(40, 30), nn.ReLU()) \n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "#通过 nnModule 用nn.sequential 可以串联起来\n",
    "net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): NestMLP(\n",
      "    (net): Sequential(\n",
      "      (0): Linear(in_features=40, out_features=30, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "  )\n",
      "  (1): Linear(in_features=30, out_features=20, bias=True)\n",
      "  (2): FancyMLP()\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型参数的初始化，访问、共享\n",
    "\n",
    "\n",
    "一般来说，用nn工具箱中的Layer 都是会自动初始化的，给随机数等，nn.init模块有多种初始化方式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Linear(in_features=4, out_features=3, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=3, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import init\n",
    "\n",
    "net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))  # pytorch已进行默认初始化\n",
    "\n",
    "print(net)\n",
    "X = torch.rand(2, 4)\n",
    "Y = net(X).sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 访问模型参数\n",
    "\n",
    "回忆一下上一节中提到的`Sequential`类与`Module`类的继承关系。对于`Sequential`实例中含模型参数的层，我们可以通过`Module`类的`parameters()`或者`named_parameters`方法来访问所有参数（以迭代器的形式返回），后者除了返回参数`Tensor`外还会返回其名字。下面，访问多层感知机`net`的所有参数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'generator'>\n",
      "0.weight /// torch.Size([3, 4])\n",
      "0.bias /// torch.Size([3])\n",
      "2.weight /// torch.Size([1, 3])\n",
      "2.bias /// torch.Size([1])\n"
     ]
    }
   ],
   "source": [
    "print(type(net.named_parameters()))\n",
    "for name, param in net.named_parameters():\n",
    "    print(name,'///' ,param.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可见返回的名字自动加上了层数的索引作为前缀。\n",
    "我们再来访问`net`中单层的参数。对于使用`Sequential`类构造的神经网络，我们可以通过方括号`[]`来访问网络的任一层。索引0表示隐藏层为`Sequential`实例最先添加的层。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weight torch.Size([3, 4]) <class 'torch.nn.parameter.Parameter'>\n",
      "Parameter containing:\n",
      "tensor([[ 0.4225,  0.4102, -0.4574, -0.2837],\n",
      "        [ 0.0826,  0.2120, -0.2627,  0.3030],\n",
      "        [ 0.1647,  0.1645, -0.1407, -0.4159]], requires_grad=True)\n",
      "bias torch.Size([3]) <class 'torch.nn.parameter.Parameter'>\n",
      "Parameter containing:\n",
      "tensor([-0.3954, -0.4933, -0.2842], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "for name, param in net[0].named_parameters():\n",
    "    print(name, param.size(), type(param))\n",
    "    print(param)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 普通tensor 和 nn.Parameter\n",
    "因为这里是单层的所以没有了层数索引的前缀。另外返回的`param`的类型为`torch.nn.parameter.Parameter`，其实这是`Tensor`的子类，和`Tensor`不同的是如果一个`Tensor`是`Parameter`，那么它会自动被添加到模型的参数列表里，来看下面这个例子。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weight1 <class 'torch.nn.parameter.Parameter'>\n"
     ]
    }
   ],
   "source": [
    "class MyModel(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(MyModel, self).__init__(**kwargs)\n",
    "        # 只有 nn.Parameter 类型的tensor 才会加入到参数列表中，可以进行求导\n",
    "        self.weight1 = nn.Parameter(torch.rand(20, 20))\n",
    "        self.weight2 = torch.rand(20, 20)\n",
    "    def forward(self, x):\n",
    "        pass\n",
    "    \n",
    "n = MyModel()\n",
    "for name, param in n.named_parameters():\n",
    "    print(name, type(param))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上面的代码中`weight1`在参数列表中但是`weight2`却没在参数列表中。\n",
    "\n",
    "因为`Parameter`是`Tensor`，即`Tensor`拥有的属性它都有，比如可以根据`data`来访问参数数值，用`grad`来访问参数梯度。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 初始化模型参数\n",
    "\n",
    "我们在3.15节（数值稳定性和模型初始化）中提到了PyTorch中`nn.Module`的模块参数都采取了较为合理的初始化策略（不同类型的layer具体采样的哪一种初始化方法的可参考[源代码](https://github.com/pytorch/pytorch/tree/master/torch/nn/modules)）。但我们经常需要使用其他方法来初始化权重。PyTorch的`init`模块里提供了多种预设的初始化方法。在下面的例子中，我们将权重参数初始化成均值为0、标准差为0.01的正态分布随机数，并依然将偏差参数清零。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.weight Parameter containing:\n",
      "tensor([[ 0.0012,  0.0022, -0.0056, -0.0029],\n",
      "        [-0.0029, -0.0102, -0.0020,  0.0151],\n",
      "        [-0.0148, -0.0132,  0.0012,  0.0214]], requires_grad=True)\n",
      "0.bias tensor([0., 0., 0.])\n",
      "2.weight Parameter containing:\n",
      "tensor([[0.0073, 0.0087, 0.0143]], requires_grad=True)\n",
      "2.bias tensor([0.])\n"
     ]
    }
   ],
   "source": [
    "for name, param in net.named_parameters():\n",
    "    if 'weight' in name:\n",
    "        torch.nn.init.normal_(param,0,0.01)\n",
    "        print(name, param)\n",
    "    if 'bias' in name:\n",
    "        init.constant_(param, val=0)\n",
    "        print(name, param.data)\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果只想对某个特定参数进行初始化，我们可以调用`Parameter`类的`initialize`函数，它与`Block`类提供的`initialize`函数的使用方法一致。下例中我们对隐藏层的权重使用Xavier随机初始化方法。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  自定义初始化方法\n",
    "\n",
    "有时候我们需要的初始化方法并没有在`init`模块中提供。这时，可以实现一个初始化方法，从而能够像使用其他初始化方法那样使用它。\n",
    "\n",
    "我们还可以通过改变这些参数的`data`来改写模型参数值同时不会影响梯度.\n",
    "\n",
    "\n",
    "## 共享模型参数\n",
    "\n",
    "在有些情况下，我们希望在多个层之间共享模型参数。4.1.3节提到了如何共享模型参数: `Module`类的`forward`函数里多次调用同一个层。此外，如果我们传入`Sequential`的模块是同一个`Module`实例的话参数也是共享的，下面来看一个例子: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Linear(in_features=1, out_features=1, bias=False)\n",
      "  (1): Linear(in_features=1, out_features=1, bias=False)\n",
      ")\n",
      "0.weight tensor([[3.]])\n"
     ]
    }
   ],
   "source": [
    "linear = nn.Linear(1, 1, bias=False)\n",
    "net = nn.Sequential(linear, linear) \n",
    "print(net)\n",
    "for name, param in net.named_parameters():\n",
    "    init.constant_(param, val=3)\n",
    "    print(name, param.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "print(id(net[0]) == id(net[1]))\n",
    "print(id(net[0].weight) == id(net[1].weight))\n",
    "#内存里其实是一个东西"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.]])\n",
      "tensor([[9.]], grad_fn=<MmBackward>)\n",
      "tensor(9., grad_fn=<SumBackward0>)\n",
      "tensor([[6.]])\n",
      "tensor([[0.]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.ones(1, 1)\n",
    "print(x)\n",
    "print(net(x))\n",
    "\n",
    "y = net(x).sum()\n",
    "\n",
    "print(y)\n",
    "\n",
    "\n",
    "y.backward()\n",
    "\n",
    "print(net[0].weight.grad) # 单次梯度是3，两次所以就是6\n",
    "\n",
    "net[0].weight.grad.zero_() \n",
    "\n",
    "print(net[0].weight.grad) # 单次梯度是3，两次所以就是6\n",
    "#weight.grad 是会累加的"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# 自定义层\n",
    " \n",
    "## 如何使用自己定义的层呢"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 不含模型参数的自定义层\n",
    "\n",
    "我们先介绍如何定义一个不含模型参数的自定义层。事实上，这和4.1节（模型构造）中介绍的使用`Module`类构造模型类似。下面的`CenteredLayer`类通过继承`Module`类自定义了一个将输入减掉均值后输出的层，并将层的计算定义在了`forward`函数里。这个层里不含模型参数。\n",
    "\n",
    "\n",
    "\n",
    "我们可以实例化这个层，然后做前向计"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class CenteredLayer(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(CenteredLayer, self).__init__(**kwargs)\n",
    "    def forward(self, x):\n",
    "        return x - x.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-2., -1.,  0.,  1.,  2.])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer = CenteredLayer()\n",
    "layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MyDense(\n",
      "  (params): ParameterList(\n",
      "      (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
      "  )\n",
      ")\n",
      "tensor([[ 1.3558,  0.7263],\n",
      "        [-3.9452,  0.8231],\n",
      "        [ 0.0279,  9.6577],\n",
      "        [ 3.6184,  7.3969],\n",
      "        [-3.0278, -0.8494],\n",
      "        [ 3.0580, -1.7761],\n",
      "        [ 6.2205,  4.9560],\n",
      "        [-2.5933, -8.6391],\n",
      "        [ 2.7387,  2.3305],\n",
      "        [-1.1401,  3.9867]], grad_fn=<MmBackward>)\n"
     ]
    }
   ],
   "source": [
    "class MyDense(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(MyDense, self).__init__()\n",
    "        \n",
    "        self.params = nn.ParameterList([nn.Parameter(torch.randn(4,4)) for i in range(3)])\n",
    "        self.params.append(nn.Parameter(torch.randn(4,2)))\n",
    "    def forward(self, x):\n",
    "        for param in self.params:\n",
    "            x = torch.mm(x, param)\n",
    "        return x\n",
    "\n",
    "net = MyDense()\n",
    "print(net)\n",
    "\n",
    "x = torch.randn(10,4)\n",
    "out = net(x)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "而`ParameterDict`接收一个`Parameter`实例的字典作为输入然后得到一个参数字典，然后可以按照字典的规则使用了。例如使用`update()`新增参数，使用`keys()`返回所有键值，使用`items()`返回所有键值对等等，可参考"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MyDictDense(\n",
      "  (params): ParameterDict(\n",
      "      (layer1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (layer2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "      (layer3): Parameter containing: [torch.FloatTensor of size 1x4]\n",
      "      (out): Parameter containing: [torch.FloatTensor of size 4x2]\n",
      "  )\n",
      ")\n",
      "tensor([[-2.5305, -2.1880],\n",
      "        [-0.4013, -0.3470],\n",
      "        [-0.4136, -0.3576],\n",
      "        [ 0.5182,  0.4480],\n",
      "        [-0.5539, -0.4790],\n",
      "        [-0.5887, -0.5090],\n",
      "        [-0.7974, -0.6895],\n",
      "        [-0.7974, -0.6895],\n",
      "        [-0.5581, -0.4826],\n",
      "        [-0.2467, -0.2133]], grad_fn=<MmBackward>)\n"
     ]
    }
   ],
   "source": [
    "class MyDictDense(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(MyDictDense, self).__init__()\n",
    "        \n",
    "        self.params = nn.ParameterDict({ 'layer1': nn.Parameter(torch.randn(4,4)),\n",
    "                                       'layer2': nn.Parameter(torch.randn(4,1)),\n",
    "                                       'layer3': nn.Parameter(torch.randn(1,4))})\n",
    "        \n",
    "        self.params.update({'out': nn.Parameter(torch.randn(4,2))})\n",
    "    def forward(self, x):\n",
    "        for layer_name in self.params:\n",
    "\n",
    "            x = torch.mm(x, self.params[layer_name])\n",
    "        return x\n",
    "\n",
    "net = MyDictDense()\n",
    "print(net)\n",
    "\n",
    "x = torch.randn(10,4)\n",
    "out = net(x)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 同样可以加入sequential\n",
    "可以通过`Module`类自定义神经网络中的层，从而可以被重复调用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): MyDictDense(\n",
      "    (params): ParameterDict(\n",
      "        (layer1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (layer2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "        (layer3): Parameter containing: [torch.FloatTensor of size 1x4]\n",
      "        (out): Parameter containing: [torch.FloatTensor of size 4x2]\n",
      "    )\n",
      "  )\n",
      "  (1): Linear(in_features=2, out_features=4, bias=True)\n",
      "  (2): MyDense(\n",
      "    (params): ParameterList(\n",
      "        (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
      "    )\n",
      "  )\n",
      ")\n",
      "tensor([[-423.7537, -281.0964],\n",
      "        [-362.3016, -239.7750],\n",
      "        [  48.4405,   36.4148],\n",
      "        [-399.2125, -264.5945],\n",
      "        [  70.8420,   51.4779],\n",
      "        [  61.0302,   44.8803],\n",
      "        [-131.3318,  -84.4671],\n",
      "        [-190.2179, -124.0631],\n",
      "        [ -37.0477,  -21.0689],\n",
      "        [  -1.8290,    2.6127]], grad_fn=<MmBackward>)\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(\n",
    "    MyDictDense(),\n",
    "    nn.Linear(2,4),\n",
    "    MyDense(),\n",
    ")\n",
    "\n",
    "print(net)\n",
    "print(net(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
